{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Relay Module-Level Pass\n",
    "\n",
    "The module-level pass decorator is `module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)`:\n",
    "\n",
    "When `pass_func` is provided, this function returns a callback. Otherwise, it serves as a decorator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm.ir.transform import module_pass\n",
    "from tvm import relay\n",
    "from tvm.relay.testing import run_infer_type"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the pass:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tvm.ir.transform.ModulePass"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@module_pass(opt_level=2)\n",
    "def transform(mod, ctx):\n",
    "    # Build a new module that defines `abs`, then merge in the input module.\n",
    "    new_mod = tvm.IRModule()\n",
    "    x = relay.var(\"x\", shape=(5, 10), dtype=\"float32\")\n",
    "    new_mod[\"abs\"] = relay.Function([x], relay.abs(x))\n",
    "    new_mod.update(mod)\n",
    "    return new_mod\n",
    "\n",
    "# The decorator replaces the function with a ModulePass object.\n",
    "type(transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can print basic information about this transform:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Run Module pass: transform at the optimization level 2"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "The meta data of the pass - pass name: transform, opt_level: 2, required passes: []"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform.info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "The meta data of the pass - pass name: transform, opt_level: 2, required passes: []"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform.pass_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "c_void_p(94634432288232)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform.handle"
   ]
  },
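  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the `transform` function adds an `abs` function to the input module, but it could be any custom module-level optimization. Once this `module_pass` has been created, it can be applied to any Relay module. For example, we can build an empty module and apply the pass to add the `abs` function."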
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "#[version = \"0.0.5\"]\n",
       "def @abs(%x: Tensor[(5, 10), float32]) {\n",
       "  abs(%x)\n",
       "}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mod = tvm.IRModule()\n",
    "mod = transform(mod)\n",
    "mod"
   ]
  },
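  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The non-decorator form mentioned at the top of this notebook, where `pass_func` is passed directly, can be sketched as follows. This is a minimal sketch: the function `identity`, the variable `identity_pass`, and the pass name `\"Identity\"` are illustrative names chosen here, not part of the original example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm.ir.transform import module_pass\n",
    "\n",
    "\n",
    "def identity(mod, ctx):\n",
    "    # Hypothetical pass function: returns the module unchanged.\n",
    "    return mod\n",
    "\n",
    "\n",
    "# Passing pass_func directly returns the pass itself rather than a decorator.\n",
    "identity_pass = module_pass(identity, opt_level=1, name=\"Identity\")\n",
    "assert isinstance(identity_pass, tvm.ir.transform.ModulePass)\n",
    "assert identity_pass.info.name == \"Identity\"\n",
    "assert identity_pass.info.opt_level == 1\n",
    "\n",
    "# Applying it to an empty module leaves the module without any functions.\n",
    "identity_mod = identity_pass(tvm.IRModule())\n",
    "assert len(identity_mod.get_global_vars()) == 0"
   ]
  },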
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `module_pass` as a class decorator\n",
    "\n",
    "`pass_func` can also be a class type with a `transform_module` method. In that case, `module_pass` creates a decorated `ModulePass` that uses `transform_module` as the pass function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "@module_pass(opt_level=1)\n",
    "class TestPipeline:\n",
    "    \"\"\"Simple test pass that optionally replaces the input module with another one.\"\"\"\n",
    "    def __init__(self, new_mod, replace):\n",
    "        self.new_mod = new_mod\n",
    "        self.replace = replace\n",
    "\n",
    "    def transform_module(self, mod, ctx):\n",
    "        if self.replace:\n",
    "            return self.new_mod\n",
    "        return mod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create instances of the custom pipeline:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = relay.var(\"x\", shape=(10, 20))\n",
    "m1 = tvm.IRModule.from_expr(relay.Function([x], x))\n",
    "m2 = tvm.IRModule.from_expr(relay.Function([x], relay.log(x)))\n",
    "\n",
    "# With replace=True the pass returns m2 regardless of its input.\n",
    "fpass = TestPipeline(m2, replace=True)\n",
    "assert fpass.info.name == \"TestPipeline\"\n",
    "mod3 = fpass(m1)\n",
    "assert mod3.same_as(m2)\n",
    "\n",
    "# With replace=False the input module is returned unchanged.\n",
    "mod4 = TestPipeline(m2, replace=False)(m1)\n",
    "assert mod4.same_as(m1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "#[version = \"0.0.5\"]\n",
       "def @main(%x: Tensor[(10, 20), float32]) {\n",
       "  log(%x)\n",
       "}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mod3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "#[version = \"0.0.5\"]\n",
       "def @main(%x: Tensor[(10, 20), float32]) {\n",
       "  %x\n",
       "}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mod4"
   ]
  },
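  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because constructor arguments are forwarded to the decorated class, one pass definition can be configured per instance. The sketch below is a hypothetical variation on the earlier `abs` example: the class name `AddUnaryOp` and its `func_name` and `op` parameters are assumptions made for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.ir.transform import module_pass\n",
    "\n",
    "\n",
    "@module_pass(opt_level=1)\n",
    "class AddUnaryOp:\n",
    "    \"\"\"Hypothetical pass that adds a one-argument function to the module.\"\"\"\n",
    "\n",
    "    def __init__(self, func_name, op):\n",
    "        self.func_name = func_name  # global name of the function to add\n",
    "        self.op = op  # Relay operator constructor, e.g. relay.exp\n",
    "\n",
    "    def transform_module(self, mod, ctx):\n",
    "        # Same structure as `transform` above: define the function, then merge.\n",
    "        new_mod = tvm.IRModule()\n",
    "        x = relay.var(\"x\", shape=(5, 10), dtype=\"float32\")\n",
    "        new_mod[self.func_name] = relay.Function([x], self.op(x))\n",
    "        new_mod.update(mod)\n",
    "        return new_mod\n",
    "\n",
    "\n",
    "add_exp = AddUnaryOp(\"exp\", relay.exp)\n",
    "assert isinstance(add_exp, tvm.ir.transform.ModulePass)\n",
    "assert add_exp.info.name == \"AddUnaryOp\"\n",
    "\n",
    "exp_mod = add_exp(tvm.IRModule())\n",
    "assert \"exp\" in [gv.name_hint for gv in exp_mod.get_global_vars()]"
   ]
  },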
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('py38': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
